
import numpy as np
from typing import Dict, List, Tuple
from collections import Counter

def dirichlet(X, targets, num_users: int, alpha: float, least_samples: int = 20) -> Tuple[Dict, Dict]:
    """Partition sample indices across clients with Dirichlet label skew.

    Adapted from
    https://github.com/KarhouTam/FL-bench/blob/master/data/utils/schemes/dirichlet.py
    following "Measuring the Effects of Non-Identical Data Distribution for
    Federated Visual Classification" (FedAvgM). For every label, that label's
    samples are split among clients according to proportions drawn from
    Dir(2 * alpha); a smaller ``alpha`` yields stronger label heterogeneity.
    Draws are repeated (rejection sampling) until the ``least_samples``
    constraint holds.

    Args:
        X: Not used by this function.
        targets: Array-like of labels supporting ``reshape(-1)`` (e.g. a NumPy
            array). Labels are assumed to be the integers ``0..C-1`` where ``C``
            is the number of distinct labels; samples carrying any other label
            are not assigned to a client.
        num_users: Number of clients (must be >= 1).
        alpha: Dirichlet concentration; proportions are drawn from
            ``Dir(2 * alpha)``.
        least_samples: Size constraint. With exactly two labels, every client
            must receive at least this many samples *of each label*, and when
            distributing the second label, clients that already hold at least
            ``len(targets) / num_users`` samples are excluded. With more than
            two labels, every client must receive at least this many samples in
            total. A large value combined with a small ``alpha`` or many
            clients can make the rejection loop slow.

    Returns:
        ``(partition, stats)``. ``partition["data_indices"]`` is a list of
        ``num_users`` index collections (int64 arrays in the two-label case,
        lists of ints otherwise); ``partition["separation"]`` is the list of
        per-label proportion vectors from the accepted draw. ``stats[u]`` is
        ``{"x": sample count, "y": Counter of labels}`` for client ``u`` and
        ``stats["sample per client"]`` summarizes the client sizes.

    Raises:
        ValueError: if ``num_users < 1``, if ``targets`` has fewer than two
            distinct labels, or if no draw could ever satisfy
            ``least_samples`` (previously an infinite loop).
    """
    if num_users < 1:
        raise ValueError(f"num_users must be >= 1, got {num_users}")

    targets = targets.reshape(-1)
    # np.unique counts distinct values directly; the former set()/bare-except
    # fallback could silently swallow unrelated errors.
    label_num = len(np.unique(targets))
    if label_num < 2:
        # Neither partitioning scheme applies; this used to surface as an
        # UnboundLocalError at the end of the function.
        raise ValueError(
            f"dirichlet partition needs at least 2 distinct labels, got {label_num}"
        )

    if label_num == 2:
        targets_numpy = np.array(targets, dtype=np.int32)
        data_idx_for_each_label = [
            np.where(targets_numpy == i)[0] for i in range(label_num)
        ]
        for k, label_idx in enumerate(data_idx_for_each_label):
            # Each client needs >= least_samples of every label, so the label
            # must have at least least_samples * num_users samples overall.
            if len(label_idx) < least_samples * num_users:
                raise ValueError(
                    f"label {k} has only {len(label_idx)} samples; cannot give each "
                    f"of {num_users} clients at least {least_samples} "
                    f"(labels are expected to be 0..{label_num - 1})"
                )
        data_indices, distrib_lst = _dirichlet_two_labels(
            data_idx_for_each_label, num_users, alpha, least_samples, len(targets_numpy)
        )
        stats = _client_stats(targets_numpy, data_indices, num_users)
        partition = {"separation": distrib_lst, "data_indices": data_indices}
    else:
        idx_per_label = [np.where(targets == k)[0] for k in range(label_num)]
        assignable = sum(len(label_idx) for label_idx in idx_per_label)
        # Every client needs >= least_samples in total.
        if assignable < least_samples * num_users:
            raise ValueError(
                f"only {assignable} samples can be assigned; cannot give each of "
                f"{num_users} clients at least {least_samples}"
            )
        data_indices, distrib_lst = _dirichlet_multi_label(
            idx_per_label, num_users, alpha, least_samples
        )
        stats = _client_stats(targets, data_indices, num_users)
        partition = {"data_indices": data_indices, "separation": distrib_lst}

    return partition, stats


def _dirichlet_two_labels(
    data_idx_for_each_label: List[np.ndarray],
    num_users: int,
    alpha: float,
    least_samples: int,
    num_samples: int,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Rejection-sample a two-label split until every client has >= least_samples of each label.

    The per-label index arrays are shuffled in place on every attempt.
    """
    fair_share = num_samples / num_users
    while True:  # do-while: always make at least one draw
        distrib_lst = []
        data_indices = [np.array([], dtype=np.int64) for _ in range(num_users)]
        smallest_per_label = []
        for label_idx in data_idx_for_each_label:
            np.random.shuffle(label_idx)
            distrib = np.random.dirichlet(np.repeat(alpha * 2, num_users))
            # Clients already holding their fair share get no samples of this label.
            sizes = np.array([len(idx_j) for idx_j in data_indices])
            distrib = distrib * (sizes < fair_share)
            distrib = distrib / distrib.sum()
            cuts = (np.cumsum(distrib) * len(label_idx)).astype(int)[:-1]
            chunks = np.split(label_idx, cuts)
            data_indices = [
                np.concatenate((idx_j, chunk)).astype(np.int64)
                for idx_j, chunk in zip(data_indices, chunks)
            ]
            distrib_lst.append(distrib)
            # Each chunk is exactly this client's share of the current label.
            smallest_per_label.append(min(len(chunk) for chunk in chunks))
        if min(smallest_per_label) >= least_samples:
            return data_indices, distrib_lst


def _dirichlet_multi_label(
    idx_per_label: List[np.ndarray],
    num_users: int,
    alpha: float,
    least_samples: int,
) -> Tuple[List[List[int]], List[np.ndarray]]:
    """Rejection-sample a multi-label split until every client has >= least_samples in total.

    The per-label index arrays are shuffled in place on every attempt.
    """
    while True:
        data_indices = [[] for _ in range(num_users)]
        distrib_lst = []
        for label_idx in idx_per_label:
            np.random.shuffle(label_idx)
            props = np.random.dirichlet([alpha * 2] * num_users)
            counts = (props * len(label_idx)).astype(int)
            # Flooring loses fewer than num_users samples; hand them to the
            # first clients so every sample of the label is assigned.
            diff = len(label_idx) - counts.sum()
            if diff > 0:
                counts[:diff] += 1
            splits = np.split(label_idx, np.cumsum(counts)[:-1])
            for u in range(num_users):
                data_indices[u].extend(splits[u].tolist())
            distrib_lst.append(props)
        if min(len(client) for client in data_indices) >= least_samples:
            return data_indices, distrib_lst


def _client_stats(labels, data_indices, num_users: int) -> Dict:
    """Build per-client size/label-count stats plus a summary of client sizes."""
    stats = {}
    for u in range(num_users):
        client_labels = labels[data_indices[u]]
        stats[u] = {"x": len(client_labels), "y": Counter(client_labels.tolist())}
    num_samples = np.array([stats[u]["x"] for u in range(num_users)])
    stats["sample per client"] = {
        # NOTE: despite its name, "std" holds the MEAN client size; the key is
        # kept unchanged for compatibility with existing consumers.
        "std": num_samples.mean(),
        "stddev": num_samples.std(),
    }
    return stats

def heterogene(X, S: np.ndarray, Y: np.ndarray, n_clients, gamma_range=(0.15, 0.85)):
    """Split samples across clients with a per-client correlation between S and Y.

    Each client ``k`` draws ``gamma_k`` via
    ``np.random.uniform(gamma_range[0], gamma_range[1])``. Samples are grouped
    by their ``(S, Y)`` value in ``{0, 1} x {0, 1}`` and each group is shuffled.
    Within a group, client ``k`` is weighted by ``gamma_k`` when ``S == Y``
    (groups (0, 0) and (1, 1)) and by ``1 - gamma_k`` otherwise; the group is
    split proportionally to these weights. Shares are floored and the leftover
    samples go one each to the lowest-numbered clients, so every sample of a
    group is assigned exactly once.

    Args:
        X: Not used by this function.
        S: Per-sample S values; only samples with S in {0, 1} are grouped.
        Y: Per-sample Y values; only samples with Y in {0, 1} are grouped.
        n_clients: Number of clients.
        gamma_range: Bounds passed to ``np.random.uniform`` for the gammas.

    Returns:
        A list of ``n_clients`` lists of sample indices. Samples whose S or Y
        lies outside {0, 1} are not assigned to any client.
    """
    low, high = gamma_range[0], gamma_range[1]
    gammas = np.random.uniform(low, high, size=n_clients)

    # Fixed group order; shuffling follows it so the random stream is
    # consumed in a deterministic order.
    group_keys = [(s_val, y_val) for s_val in (0, 1) for y_val in (0, 1)]
    pools = {
        key: np.where((S == key[0]) & (Y == key[1]))[0].tolist()
        for key in group_keys
    }
    for key in group_keys:
        np.random.shuffle(pools[key])

    partitions = [[] for _ in range(n_clients)]
    for s_val, y_val in group_keys:
        pool = pools[(s_val, y_val)]
        size = len(pool)

        # Aligned groups (S == Y) use gamma_k, misaligned ones 1 - gamma_k.
        if s_val == y_val:
            shares = [gammas[c] for c in range(n_clients)]
        else:
            shares = [1 - gammas[c] for c in range(n_clients)]
        total_share = sum(shares)

        # Floor each proportional share, then hand out the remainder one by one
        # starting from client 0 so the group is fully consumed.
        quota = {
            c: int(np.floor(shares[c] / total_share * size))
            for c in range(n_clients)
        }
        leftover = size - sum(quota.values())
        for c in range(leftover):
            quota[c] += 1

        cursor = 0
        for c in range(n_clients):
            take = quota[c]
            partitions[c].extend(pool[cursor:cursor + take])
            cursor += take

    return partitions
